from typing import List

from openai import OpenAI

from .base import BaseEmb


class OpenAIEmb(BaseEmb):
    """Embedding backend that delegates to an ``openai.OpenAI`` client.

    The client is built once in the constructor and reused by every
    ``get_emb`` call.
    """

    def __init__(self, model_name: str, api_key: str, base_url: str, **kwargs):
        """Initialise the base class and create the API client.

        Args:
            model_name: Forwarded to ``BaseEmb.__init__``. ``get_emb`` later
                reads ``self.model_name`` (presumably stored by the base
                class -- confirm against ``.base``).
            api_key: Passed to ``OpenAI`` as ``api_key``.
            base_url: Passed to ``OpenAI`` as ``base_url``.
            **kwargs: Forwarded unchanged to ``BaseEmb.__init__``.
        """
        super().__init__(model_name=model_name, **kwargs)
        openai_client = OpenAI(api_key=api_key, base_url=base_url)
        self.client = openai_client

    def get_emb(self, text: str) -> List[float]:
        """Return the embedding for ``text``.

        Issues one ``embeddings.create`` request with ``self.model_name`` as
        the model and ``text`` as the input, then returns the ``embedding``
        attribute of the first entry in the response's ``data``.

        Args:
            text: The input sent to the embeddings endpoint.

        Returns:
            The ``embedding`` of ``response.data[0]``.
        """
        embeddings_api = self.client.embeddings
        response = embeddings_api.create(model=self.model_name, input=text)
        first_entry = response.data[0]
        return first_entry.embedding
